Arborenv In-Need Regions ML Usecase
In [1]:
# Common
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from IPython.display import clear_output as cls
# Data
from tqdm import tqdm
import tensorflow.data as tfd
# Data Visualization
import matplotlib.pyplot as plt
# Model Building
from tensorflow.keras import layers
from tensorflow.keras import callbacks
from tensorflow.keras import optimizers
from tensorflow.keras import metrics
from tensorflow.keras.optimizers.schedules import ExponentialDecay
# Model visualization
from tensorflow.keras.utils import plot_model
# Extra
from typing import List, Tuple, Union
Hyperparameters and constants¶
In [2]:
# Image and Mask Dimensions
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 160
N_IMAGE_CHANNELS = 3
N_MASK_CHANNELS = 1
# Image and Mask Size
IMAGE_SIZE = (IMAGE_WIDTH, IMAGE_HEIGHT, N_IMAGE_CHANNELS)
MASK_SIZE = (IMAGE_WIDTH, IMAGE_HEIGHT, N_MASK_CHANNELS)
# Batch Size and Learning Rate
BATCH_SIZE = 32
BASE_LR = 1e-2
# Model Name
MODEL_NAME = 'UNetForestSegmentation'
# Model Training
EPOCHS = 100
# Data Paths
ROOT_IMAGE_DIR = 'Forest_Segmented/images'
ROOT_MASK_DIR = 'Forest_Segmented/masks'
METADATA_CSV_PATH = 'Forest_Segmented/meta_data.csv'
# Model Architecture
FILTERS = 32
In [3]:
# Random Seed
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)
Utility Functions¶
In [4]:
def load_image_and_mask(image_path: str, mask_path: str) -> Tuple[tf.Tensor, tf.Tensor]:
'''
This function takes the file paths of an image and its corresponding mask as input. It first reads the images, then decodes them into tensors,
and resizes them to a standard size. After that, the image and mask tensors are normalized by clipping the pixel values between 0 and 1.
Finally, the function converts the image and mask tensors to the float32 data type and returns them as a tuple.
Arguments :
image_path : The path to the image to be loaded.
mask_path : The path to the mask to be loaded.
Returns :
image : This is the loaded and the processed image.
mask : This is the loaded and the processed mask.
'''
# Read the images
image = tf.io.read_file(filename = image_path)
mask = tf.io.read_file(filename = mask_path)
# Decode the images
image = tf.image.decode_jpeg(contents = image, channels = N_IMAGE_CHANNELS)
mask = tf.image.decode_jpeg(contents = mask, channels = N_MASK_CHANNELS)
# Convert the image to a Tensor
image = tf.image.convert_image_dtype(image = image, dtype = tf.float32)
mask = tf.image.convert_image_dtype(image = mask, dtype = tf.float32)
# Resize the image to the desired dimensions
image = tf.image.resize(images = image, size = (IMAGE_WIDTH, IMAGE_HEIGHT))
mask = tf.image.resize(images = mask, size = (IMAGE_WIDTH, IMAGE_HEIGHT))
# Normalize the image
image = tf.clip_by_value(image, clip_value_min = 0.0, clip_value_max = 1.0)
mask = tf.clip_by_value(mask, clip_value_min = 0.0, clip_value_max = 1.0)
# Final conversion
image = tf.cast(image, dtype = tf.float32)
mask = tf.cast(mask, dtype = tf.float32)
return image, mask
In [5]:
# Load CSV File
metadata = pd.read_csv(METADATA_CSV_PATH)
# Quick look
metadata.head()
Out[5]:
| image | mask | |
|---|---|---|
| 0 | 10452_sat_08.jpg | 10452_mask_08.jpg |
| 1 | 10452_sat_18.jpg | 10452_mask_18.jpg |
| 2 | 111335_sat_00.jpg | 111335_mask_00.jpg |
| 3 | 111335_sat_01.jpg | 111335_mask_01.jpg |
| 4 | 111335_sat_02.jpg | 111335_mask_02.jpg |
In [6]:
# Define indices
start_index = 0
end_index = 1500
# Slice metadata to get the first 2000 entries
metadata_subset = metadata.iloc[start_index:end_index].copy()
# Add root path to image file names and ensure forward slashes
metadata_subset['image'] = [os.path.normpath(os.path.join(ROOT_IMAGE_DIR, filename)).replace('\\', '/') for filename in metadata_subset['image']]
# Add root path to mask file names and ensure forward slashes
metadata_subset['mask'] = [os.path.normpath(os.path.join(ROOT_MASK_DIR, filename)).replace('\\', '/') for filename in metadata_subset['mask']]
In [7]:
# Quick Check
metadata_subset.head()
Out[7]:
| image | mask | |
|---|---|---|
| 0 | Forest_Segmented/images/10452_sat_08.jpg | Forest_Segmented/masks/10452_mask_08.jpg |
| 1 | Forest_Segmented/images/10452_sat_18.jpg | Forest_Segmented/masks/10452_mask_18.jpg |
| 2 | Forest_Segmented/images/111335_sat_00.jpg | Forest_Segmented/masks/111335_mask_00.jpg |
| 3 | Forest_Segmented/images/111335_sat_01.jpg | Forest_Segmented/masks/111335_mask_01.jpg |
| 4 | Forest_Segmented/images/111335_sat_02.jpg | Forest_Segmented/masks/111335_mask_02.jpg |
In [8]:
def load_dataset(
image_paths: list, mask_paths: list, split_ratio: float=0.7,
batch_size: int=BATCH_SIZE, shuffle: bool=True,
buffer_size: int=1000, n_repeat: int=1
) -> Union[Tuple[tfd.Dataset, tfd.Dataset], tfd.Dataset]:
'''
This function loads the image and mask data from the provided file paths and creates a TensorFlow dataset. The function
first creates space to store the image and mask data in numpy arrays. It then iterates over each image and mask pair,
loading them using the load_image_and_mask function and storing them in the numpy arrays.
The function then creates a TensorFlow dataset using the numpy arrays. If shuffle is True, it shuffles the dataset
with a buffer size of buffer_size. If split_ratio is not None, it splits the dataset into two parts with sizes determined
by the split_ratio, and converts them into batches of size batch_size with drop_remainder=True. The two resulting datasets
are returned as a tuple.
If split_ratio is None, the entire dataset is converted into batches of size batch_size with drop_remainder=True,
and the resulting dataset is returned.
Args:
image_paths: A list of strings, containing the file paths of the input images.
mask_paths: A list of strings, containing the file paths of the corresponding mask images.
split_ratio: A float value between 0 and 1, representing the ratio of data to be used for validation.
If split_ratio is set to None, then no data will be split for validation.
batch_size: An integer, representing the batch size for the input data.
shuffle: A boolean value indicating whether the data should be shuffled or not.
buffer_size: An integer, representing the buffer size for shuffling the data.
n_repeat: An integer, representing the total number of repetations of the data.
Returns:
If split_ratio is not None, then the function returns a tuple of two Tensorflow datasets.
The first dataset contains the training data and the second dataset contains the validation data.
If split_ratio is None, then the function returns a single Tensorflow dataset containing the
input data batched and pre-fetched for training.
'''
# Create space for storing the data.
images = np.empty(shape=(len(image_paths), *IMAGE_SIZE), dtype=np.float32)
masks = np.empty(shape=(len(mask_paths), *MASK_SIZE), dtype=np.float32)
# Iterate over the data.
index = 0
for image_path, mask_path in tqdm(zip(image_paths, mask_paths), desc='Loading'):
# Load the image and the mask.
image, mask = load_image_and_mask(image_path = image_path, mask_path = mask_path)
# Store the image and the mask.
images[index] = image
masks[index] = mask
# Increment the index.
index += 1
# Create a Tensorflow data.
data_set = tfd.Dataset.from_tensor_slices((images, masks)).repeat(n_repeat)
# Shuffle the data set.
if shuffle:
data_set = data_set.shuffle(buffer_size)
# Split the data
if split_ratio is not None:
split_ratio_val_test = (1 - split_ratio)/2
# Calculate new data sizes after splitting.
data_1_len = int(split_ratio * len(images))
print(data_1_len)
data_2_len = int(split_ratio_val_test * len(images))
print(data_2_len)
# Divide the data into 2 parts.
data_1 = data_set.take(data_1_len)
data_2 = data_set.skip(data_1_len).take(data_2_len)
data_3 = data_set.skip(data_1_len + data_2_len).take(data_2_len)
# Convert data into batches.
data_1 = data_1.batch(batch_size, drop_remainder=True).prefetch(tfd.AUTOTUNE)
data_2 = data_2.batch(batch_size, drop_remainder=True).prefetch(tfd.AUTOTUNE)
data_3 = data_3.batch(batch_size, drop_remainder=True).prefetch(tfd.AUTOTUNE)
# Return the data
return data_1, data_2, data_3
else:
# Convert data into batches
data_set = data_set.batch(batch_size, drop_remainder=True).prefetch(tfd.AUTOTUNE)
# Return the data
return data_set
In [9]:
# Training and Testing Data
train_ds, test_ds, valid_ds = load_dataset(
image_paths = metadata_subset['image'],
mask_paths = metadata_subset['mask'],
split_ratio = 0.7,
shuffle = True,
n_repeat=3,
)
Loading: 1500it [00:03, 498.11it/s]
1050 225
In [10]:
print("*"*100)
print(f"{' '*30}Training Data Size : {train_ds.cardinality().numpy() * BATCH_SIZE}")
print(f"{' '*30}Testing Data Size : {test_ds.cardinality().numpy() * BATCH_SIZE}")
print(f"{' '*30}Validation Data Size : {valid_ds.cardinality().numpy() * BATCH_SIZE}")
print("*"*100)
****************************************************************************************************
Training Data Size : 1024
Testing Data Size : 224
Validation Data Size : 224
****************************************************************************************************
In [11]:
# # Training Data size
# full_train_size = full_train_ds.cardinality().numpy()
# # Split Ratio
# train_val_split = 0.1
# valid_size = int(full_train_size * train_val_split)
# train_size = full_train_size - valid_size
# # Split Data
# train_ds = full_train_ds.take(train_size)
# valid_ds = full_train_ds.skip(train_size).take(valid_size)
In [12]:
train_ds.cardinality().numpy(),valid_ds.cardinality().numpy(),test_ds.cardinality().numpy()
Out[12]:
(32, 7, 7)
In [13]:
# print("*"*100)
# print(f"{' '*30}Training Data Size : {train_ds.cardinality().numpy() * BATCH_SIZE}")
# print(f"{' '*30}Validation Data Size : {valid_ds.cardinality().numpy() * BATCH_SIZE}")
# print(f"{' '*30}Testing Data Size : {test_ds.cardinality().numpy() * BATCH_SIZE}")
# print("*"*100)
Data Visualization¶
In [14]:
def show_images_and_masks(data : tfd.Dataset, n_images: int=10, FIGSIZE: tuple=(25, 5), model: tf.keras.Model=None):
# Configuration
if model is None:
n_cols = 3
else:
n_cols = 5
# Collect the data
images, masks = next(iter(data))
# Iterate over the data
for n in range(n_images):
# Plotting configuration
plt.figure(figsize=FIGSIZE)
# Plot the image
plt.subplot(1, n_cols, 1)
plt.title("Original Image")
plt.imshow(images[n])
plt.axis('off')
# Plot the Mask
plt.subplot(1, n_cols, 2)
plt.title("Original Mask")
plt.imshow(masks[n], cmap='gray')
plt.axis('off')
# Plot image and mask overlay
plt.subplot(1, n_cols, 3)
plt.title('Image and Mask overlay')
plt.imshow(masks[n], alpha=0.8, cmap='binary_r')
plt.imshow(images[n], alpha=0.5)
plt.axis('off')
# Model predictions
if model is not None:
pred_mask = model.predict(tf.expand_dims(images[n], axis=0))[0]
pred_mask = pred_mask>=0.5 # threshold = 0.5
plt.subplot(1, n_cols, 4)
plt.title('Predicted Mask')
plt.imshow(pred_mask, cmap='gray')
plt.axis('off')
plt.subplot(1, n_cols, 5)
plt.title('Predicted Mask Overlay')
plt.imshow(pred_mask, alpha=0.8, cmap='binary_r')
plt.imshow(images[n], alpha=0.5)
plt.axis('off')
# Show final plot
plt.show()
show_images_and_masks(data=train_ds)
U-Net¶
- Image segmentation model
- Encoder and Decoder with skip connection
- Total: Encoder -16 layers, Decoder - 16 layers
- Parametes:
Unet - Encoder Block¶
In [15]:
class EncoderBlock(layers.Layer):
def __init__(self, filters: int, max_pool: bool=True, rate=0.2, **kwargs) -> None:
super().__init__(**kwargs)
# Params
self.rate = rate
self.filters = filters
self.max_pool = max_pool
# Layers : Initialize the model layers that will be later called
self.max_pooling = layers.MaxPool2D(pool_size=(2,2), strides=(2,2))
self.conv1 = layers.Conv2D(
filters=filters,
kernel_size=3,
strides=1,
padding='same',
activation='relu',
kernel_initializer='he_normal'
)
self.conv2 = layers.Conv2D(
filters=filters,
kernel_size=3,
strides=1,
padding='same',
activation='relu',
kernel_initializer='he_normal'
)
self.drop = layers.Dropout(rate)
self.bn = layers.BatchNormalization()
def call(self, X, **kwargs):
X = self.bn(X) # BatchNomlarization
X = self.conv1(X)
X = self.drop(X)
X = self.conv2(X)
# Apply Max Pooling if required
if self.max_pool:
y = self.max_pooling(X)
return y, X
else:
return X
def get_config(self):
config = super().get_config()
config.update({
'filters': self.filters,
'max_pool': self.max_pool,
'rate': self.rate
})
def __repr__(self):
return f"{self.__class__.name}(F={self.filters}, Pooling={self.max_pool})"
UNet - Decoder Block¶
In [16]:
class DecoderBlock(layers.Layer):
def __init__(self, filters: int, rate: float = 0.2, **kwargs):
super().__init__(**kwargs)
self.filters = filters
self.rate = rate
# Initialize the model layers
self.convT = layers.Conv2DTranspose(
filters = filters,
kernel_size = 3,
strides = 2,
padding = 'same',
activation = 'relu',
kernel_initializer = 'he_normal'
)
self.bn = layers.BatchNormalization()
self.net = EncoderBlock(filters = filters, rate = rate, max_pool = False)
def call(self, inputs, **kwargs):
# Get both the inputs
X, skip_X = inputs
# Up-sample the skip connection
X = self.bn(X)
X = self.convT(X)
# Concatenate both inputs
X = layers.Concatenate(axis=-1)([X, skip_X])
X = self.net(X)
return X
def get_config(self):
config = super().get_config()
config.update({
'filters': self.filters,
'rate': self.rate,
})
return config
def __repr__(self):
return f"{self.__class__.__name__}(F={self.filters}, rate={self.rate})"
UNet - Encoder Decoder Net¶
In [17]:
# Input Layer
input_layer = layers.Input(shape=(IMAGE_SIZE), name="InputLayer")
# The encoder network
pool1, encoder1 = EncoderBlock(FILTERS, max_pool=True, rate=0.1, name="EncoderLayer1")(input_layer)
pool2, encoder2 = EncoderBlock(FILTERS*2, max_pool=True, rate=0.1, name="EncoderLayer2")(pool1)
pool3, encoder3 = EncoderBlock(FILTERS*4, max_pool=True, rate=0.2, name="EncoderLayer3")(pool2)
pool4, encoder4 = EncoderBlock(FILTERS*8, max_pool=True, rate=0.2, name="EncoderLayer4")(pool3)
# The encoder encoding
encoding = EncoderBlock(FILTERS*16, max_pool=False, rate=0.3, name="EncodingSpace")(pool4)
# The decoder network
decoder4 = DecoderBlock(FILTERS*8, rate=0.2, name="DecoderLayer1")([encoding, encoder4])
decoder3 = DecoderBlock(FILTERS*4, rate=0.2, name="DecoderLayer2")([decoder4, encoder3])
decoder2 = DecoderBlock(FILTERS*2, rate=0.1, name="DecoderLayer3")([decoder3, encoder2])
decoder1 = DecoderBlock(FILTERS, rate=0.1, name="DecoderLayer4")([decoder2, encoder1])
# Final output layer.
final_conv = layers.Conv2D(
filters = 1,
kernel_size = 1,
strides=1,
padding='same',
activation='sigmoid',
name="OutputMap"
)(decoder1)
# Unet Model
unet_model = keras.Model(
inputs = input_layer,
outputs = final_conv,
name = "UNetModel"
)
WARNING:tensorflow:From C:\Users\Salma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\LocalCache\local-packages\Python312\site-packages\keras\src\backend\tensorflow\core.py:192: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.
C:\Users\Salma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\LocalCache\local-packages\Python312\site-packages\keras\src\layers\layer.py:372: UserWarning: `build()` was called on layer 'EncodingSpace', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method. warnings.warn( C:\Users\Salma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\LocalCache\local-packages\Python312\site-packages\keras\src\layers\layer.py:372: UserWarning: `build()` was called on layer 'encoder_block', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method. warnings.warn( C:\Users\Salma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\LocalCache\local-packages\Python312\site-packages\keras\src\layers\layer.py:372: UserWarning: `build()` was called on layer 'encoder_block_1', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method. warnings.warn( C:\Users\Salma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\LocalCache\local-packages\Python312\site-packages\keras\src\layers\layer.py:372: UserWarning: `build()` was called on layer 'encoder_block_2', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method. warnings.warn( C:\Users\Salma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\LocalCache\local-packages\Python312\site-packages\keras\src\layers\layer.py:372: UserWarning: `build()` was called on layer 'encoder_block_3', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method. warnings.warn(
In [18]:
# Model Summary
unet_model.summary()
Model: "UNetModel"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ │ InputLayer │ (None, 160, 160, │ 0 │ - │ │ (InputLayer) │ 3) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ EncoderLayer1 │ [(None, 80, 80, │ 10,156 │ InputLayer[0][0] │ │ (EncoderBlock) │ 32), (None, 160, │ │ │ │ │ 160, 32)] │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ EncoderLayer2 │ [(None, 40, 40, │ 55,552 │ EncoderLayer1[0]… │ │ (EncoderBlock) │ 64), (None, 80, │ │ │ │ │ 80, 64)] │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ EncoderLayer3 │ [(None, 20, 20, │ 221,696 │ EncoderLayer2[0]… │ │ (EncoderBlock) │ 128), (None, 40, │ │ │ │ │ 40, 128)] │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ EncoderLayer4 │ [(None, 10, 10, │ 885,760 │ EncoderLayer3[0]… │ │ (EncoderBlock) │ 256), (None, 20, │ │ │ │ │ 20, 256)] │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ EncodingSpace │ (None, 10, 10, │ 3,540,992 │ EncoderLayer4[0]… │ │ (EncoderBlock) │ 512) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ DecoderLayer1 │ (None, 20, 20, │ 2,953,984 │ EncodingSpace[0]… │ │ (DecoderBlock) │ 256) │ │ EncoderLayer4[0]… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ DecoderLayer2 │ (None, 40, 40, │ 739,712 │ DecoderLayer1[0]… │ │ (DecoderBlock) │ 128) │ │ EncoderLayer3[0]… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ DecoderLayer3 │ (None, 80, 80, │ 185,536 │ DecoderLayer2[0]… │ │ (DecoderBlock) │ 64) │ │ EncoderLayer2[0]… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ DecoderLayer4 │ (None, 160, 160, │ 46,688 │ DecoderLayer3[0]… │ │ (DecoderBlock) │ 32) │ │ EncoderLayer1[0]… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ OutputMap (Conv2D) │ (None, 160, 160, │ 33 │ DecoderLayer4[0]… │ │ │ 1) │ │ │ └─────────────────────┴───────────────────┴────────────┴───────────────────┘
Total params: 8,640,109 (32.96 MB)
Trainable params: 8,635,303 (32.94 MB)
Non-trainable params: 4,806 (18.77 KB)
In [19]:
# Inspect trainable and non-trainable parameters
for layer in unet_model.layers:
print(f"Layer: {layer.name}")
print(f" Trainable: {layer.trainable}")
print(f" Non-trainable weights: {len(layer.non_trainable_weights)}")
print(f" Trainable weights: {len(layer.trainable_weights)}")
Layer: InputLayer Trainable: True Non-trainable weights: 0 Trainable weights: 0 Layer: EncoderLayer1 Trainable: True Non-trainable weights: 2 Trainable weights: 6 Layer: EncoderLayer2 Trainable: True Non-trainable weights: 2 Trainable weights: 6 Layer: EncoderLayer3 Trainable: True Non-trainable weights: 2 Trainable weights: 6 Layer: EncoderLayer4 Trainable: True Non-trainable weights: 2 Trainable weights: 6 Layer: EncodingSpace Trainable: True Non-trainable weights: 2 Trainable weights: 6 Layer: DecoderLayer1 Trainable: True Non-trainable weights: 4 Trainable weights: 10 Layer: DecoderLayer2 Trainable: True Non-trainable weights: 4 Trainable weights: 10 Layer: DecoderLayer3 Trainable: True Non-trainable weights: 4 Trainable weights: 10 Layer: DecoderLayer4 Trainable: True Non-trainable weights: 4 Trainable weights: 10 Layer: OutputMap Trainable: True Non-trainable weights: 0 Trainable weights: 2
UNet - Model Training¶
In [20]:
class ShowProgress(callbacks.Callback):
"""A callback that displays the original image, the original mask,
the predicted mask, and the Grad-CAM visualization for a sample image
after each epoch of training.
Args:
data (tf.data.Dataset): A dataset of image-mask pairs.
layer_name (str): The name of the layer to use for Grad-CAM.
cmap (str, optional): The colormap to use for displaying the masks.
Defaults to 'gray'.
output_dir (str, optional): The directory to save the output images.
If None, the images will not be saved. Defaults to None.
num_images (int, optional): The number of images to display.
Defaults to 1.
file_format (str, optional): The format to save the output images in.
Defaults to 'png'.
"""
def __init__(self, data: tf.data.Dataset, layer_name: str, cmap: str = 'gray',
output_dir: str = None, num_images: int = 1, file_format: str = 'png',
**kwargs):
super().__init__(**kwargs)
# Validate inputs
if not isinstance(data, tf.data.Dataset):
raise ValueError('The `data` parameter must be a tf.data.Dataset.')
if not isinstance(layer_name, str):
raise ValueError('The `layer_name` parameter must be a string.')
if not isinstance(num_images, int) or num_images < 1:
raise ValueError('The `num_images` parameter must be an integer greater than 0.')
if file_format not in ['png', 'jpg', 'pdf']:
raise ValueError('The `file_format` parameter must be "png", "jpg", or "pdf".')
self.data = data
self.layer_name = layer_name
self.cmap = cmap
self.output_dir = output_dir
self.num_images = num_images
self.file_format = file_format
def on_epoch_end(self, epoch, logs=None):
"""Displays the original image, the original mask, the predicted mask,
and the Grad-CAM visualization for a sample image.
"""
# Plotting configuration
plt.figure(figsize=(25, 8 * self.num_images))
for i in range(self.num_images):
# Get Data
images, masks = next(iter(self.data))
images = images.numpy()
masks = masks.numpy()
# Select image
index = np.random.randint(len(images))
image, mask = images[index], masks[index]
# Make Prediction
pred_mask = self.model.predict(np.expand_dims(image, axis=0))[0]
# Show Image
plt.subplot(1, 3, 1)
plt.title("Original Image")
plt.imshow(image)
plt.axis('off')
# Show Mask
plt.subplot(1, 3, 2)
plt.title("Original Mask")
plt.imshow(mask, cmap=self.cmap)
plt.axis('off')
# Show Model Pred
plt.subplot(1, 3, 3)
plt.title("Predicted Mask")
plt.imshow(pred_mask, cmap=self.cmap)
plt.axis('off')
# Save figure
if self.output_dir is not None:
path = os.path.join(os.curdir, self.output_dir)
plt.savefig(f'Epoch({epoch+1})-Viz.{self.file_format}')
# Show Final plot
plt.show()
In [21]:
test_images, test_masks = next(iter(test_ds))
CALLBACKS = [
callbacks.EarlyStopping(
patience = 10,
restore_best_weights = True),
# callbacks.ModelCheckpoint(
# MODEL_NAME + '.h5',
# save_best_only = True),
ShowProgress(
data = valid_ds,
layer_name = "DecoderLayer4"
)
]
In [22]:
def dice_coeff(y_true: tf.Tensor, y_pred: tf.Tensor, smooth: float=1.0) -> tf.Tensor:
"""Compute the Dice coefficient between predicted and true masks.
Args:
y_true (tf.Tensor): True masks. Shape (batch_size, height, width, num_channels).
y_pred (tf.Tensor): Predicted masks. Shape (batch_size, height, width, num_channels).
smooth (float): Smoothing factor to avoid division by zero.
Returns:
tf.Tensor: Dice coefficient score.
"""
y_true = tf.cast(y_true, tf.float32)
y_pred = tf.cast(y_pred, tf.float32)
intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])
union = tf.reduce_sum(y_true, axis=[1, 2, 3]) + tf.reduce_sum(y_pred, axis=[1, 2, 3])
dice = tf.reduce_mean((2.0 * intersection + smooth) / (union + smooth), axis=0)
return tf.cast(dice, tf.float32)
In [23]:
# Pixel Accuracy
pixel_acc = metrics.Accuracy(name="PixelAccuracy")
# Mean Intersection Over Union
mean_iou = metrics.MeanIoU(num_classes=2, name="MeanIoU")
# Exponential learning rate decay
'''
For example,
At Step 0: Learning Rate = 0.001
At Step 500: Learning Rate = 0.001 * 0.96 = 0.00096
At Step 1000: Learning Rate = 0.00096 * 0.96 = 0.0009216
And so on...
'''
initial_learning_rate = BASE_LR
decay_steps = 500 # learning rate will be updated everey 500 steps,
decay_rate = 0.96 # learning rate = learning rate * decay_reate, so lr will decrease 4%
lr_schedule = ExponentialDecay(
initial_learning_rate,
decay_steps,
decay_rate,
staircase=True # learnning rate will drop like step, if staircase=False, it will drop smoothly.
)
optimizer = optimizers.Adam(learning_rate=lr_schedule)
# Compile Model
unet_model.compile(
loss = 'binary_crossentropy',
optimizer = optimizer,
metrics = [
pixel_acc,
mean_iou,
dice_coeff
]
)
In [24]:
# Model Training
unet_model_history = unet_model.fit(
train_ds,
validation_data = valid_ds,
epochs = EPOCHS,
callbacks = CALLBACKS,
batch_size = BATCH_SIZE,
)
Epoch 1/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 341ms/step- MeanIoU: 0.1953 - PixelAccuracy: 7.7410e-07 - dice_coeff: 0.5757 - loss: 0.73
32/32 ━━━━━━━━━━━━━━━━━━━━ 66s 2s/step - MeanIoU: 0.1955 - PixelAccuracy: 7.5642e-07 - dice_coeff: 0.5764 - loss: 0.7291 - val_MeanIoU: 0.2000 - val_PixelAccuracy: 0.3758 - val_dice_coeff: 0.0024 - val_loss: 2211.1560 Epoch 2/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - MeanIoU: 0.2037 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6165 - loss: 0.536
32/32 ━━━━━━━━━━━━━━━━━━━━ 66s 2s/step - MeanIoU: 0.2038 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6165 - loss: 0.5363 - val_MeanIoU: 0.2040 - val_PixelAccuracy: 0.3475 - val_dice_coeff: 0.0017 - val_loss: 158.6864 Epoch 3/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - MeanIoU: 0.2136 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6170 - loss: 0.510
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2134 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6172 - loss: 0.5109 - val_MeanIoU: 0.2069 - val_PixelAccuracy: 0.3203 - val_dice_coeff: 0.0035 - val_loss: 78.7073 Epoch 4/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - MeanIoU: 0.1999 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6504 - loss: 0.478
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2000 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6503 - loss: 0.4788 - val_MeanIoU: 0.2123 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.0155 - val_loss: 25.3732 Epoch 5/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - MeanIoU: 0.1975 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6645 - loss: 0.465
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.1975 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6644 - loss: 0.4655 - val_MeanIoU: 0.2029 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.0423 - val_loss: 19.0655 Epoch 6/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 70ms/step - MeanIoU: 0.2029 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6370 - loss: 0.470
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2029 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6371 - loss: 0.4714 - val_MeanIoU: 0.2169 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.0794 - val_loss: 2.7798 Epoch 7/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - MeanIoU: 0.2124 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6254 - loss: 0.474
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2122 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6258 - loss: 0.4750 - val_MeanIoU: 0.1785 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.0153 - val_loss: 14.4787 Epoch 8/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - MeanIoU: 0.2053 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6446 - loss: 0.458
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2052 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6445 - loss: 0.4592 - val_MeanIoU: 0.1998 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.0470 - val_loss: 4.3547 Epoch 9/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - MeanIoU: 0.2048 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6372 - loss: 0.466
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2050 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6370 - loss: 0.4672 - val_MeanIoU: 0.1840 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.1318 - val_loss: 2.5210 Epoch 10/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - MeanIoU: 0.2081 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6429 - loss: 0.448
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2080 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6432 - loss: 0.4483 - val_MeanIoU: 0.2063 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.2170 - val_loss: 2.6700 Epoch 11/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - MeanIoU: 0.2113 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6418 - loss: 0.465
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2112 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6419 - loss: 0.4658 - val_MeanIoU: 0.2040 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.5801 - val_loss: 0.6076 Epoch 12/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - MeanIoU: 0.2036 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6475 - loss: 0.450
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2035 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6479 - loss: 0.4503 - val_MeanIoU: 0.2094 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.2787 - val_loss: 2.6653 Epoch 13/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - MeanIoU: 0.2068 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6562 - loss: 0.449
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2068 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6561 - loss: 0.4489 - val_MeanIoU: 0.2000 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.4756 - val_loss: 1.1738 Epoch 14/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - MeanIoU: 0.2105 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6632 - loss: 0.421
32/32 ━━━━━━━━━━━━━━━━━━━━ 66s 2s/step - MeanIoU: 0.2105 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6629 - loss: 0.4217 - val_MeanIoU: 0.1981 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.2818 - val_loss: 1.7027 Epoch 15/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - MeanIoU: 0.2014 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6603 - loss: 0.448
32/32 ━━━━━━━━━━━━━━━━━━━━ 67s 2s/step - MeanIoU: 0.2015 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6602 - loss: 0.4490 - val_MeanIoU: 0.2160 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.4432 - val_loss: 1.3020 Epoch 16/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - MeanIoU: 0.2137 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6348 - loss: 0.457
32/32 ━━━━━━━━━━━━━━━━━━━━ 67s 2s/step - MeanIoU: 0.2133 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6357 - loss: 0.4576 - val_MeanIoU: 0.2224 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.5913 - val_loss: 0.7013 Epoch 17/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - MeanIoU: 0.2055 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6704 - loss: 0.423
32/32 ━━━━━━━━━━━━━━━━━━━━ 66s 2s/step - MeanIoU: 0.2057 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6702 - loss: 0.4240 - val_MeanIoU: 0.2235 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.3283 - val_loss: 1.2856 Epoch 18/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - MeanIoU: 0.2026 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6732 - loss: 0.402
32/32 ━━━━━━━━━━━━━━━━━━━━ 68s 2s/step - MeanIoU: 0.2026 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6733 - loss: 0.4029 - val_MeanIoU: 0.1886 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.4122 - val_loss: 1.3875 Epoch 19/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - MeanIoU: 0.2131 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6556 - loss: 0.426
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2129 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6562 - loss: 0.4263 - val_MeanIoU: 0.2155 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.5022 - val_loss: 0.8983 Epoch 20/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - MeanIoU: 0.2023 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6678 - loss: 0.416
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2025 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6679 - loss: 0.4163 - val_MeanIoU: 0.2131 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6230 - val_loss: 0.5877 Epoch 21/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - MeanIoU: 0.2027 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6764 - loss: 0.405
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2028 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6762 - loss: 0.4058 - val_MeanIoU: 0.2200 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.4770 - val_loss: 0.8206 Epoch 22/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - MeanIoU: 0.2135 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6626 - loss: 0.416
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2131 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6630 - loss: 0.4163 - val_MeanIoU: 0.1916 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.7351 - val_loss: 0.5219 Epoch 23/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 73ms/step - MeanIoU: 0.2094 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6714 - loss: 0.411
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2093 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6715 - loss: 0.4117 - val_MeanIoU: 0.1996 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.7330 - val_loss: 0.6408 Epoch 24/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - MeanIoU: 0.2061 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6814 - loss: 0.399
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2062 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6810 - loss: 0.4002 - val_MeanIoU: 0.1916 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6788 - val_loss: 0.4418 Epoch 25/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - MeanIoU: 0.2006 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6744 - loss: 0.417
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2008 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6743 - loss: 0.4173 - val_MeanIoU: 0.2134 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.5442 - val_loss: 0.5786 Epoch 26/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 70ms/step - MeanIoU: 0.2001 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6675 - loss: 0.425
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2001 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6678 - loss: 0.4251 - val_MeanIoU: 0.2055 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6549 - val_loss: 0.4363 Epoch 27/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 78ms/step - MeanIoU: 0.2041 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6863 - loss: 0.386
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2041 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6861 - loss: 0.3869 - val_MeanIoU: 0.2087 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.5097 - val_loss: 0.7656 Epoch 28/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - MeanIoU: 0.1990 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6716 - loss: 0.393
32/32 ━━━━━━━━━━━━━━━━━━━━ 66s 2s/step - MeanIoU: 0.1991 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6715 - loss: 0.3941 - val_MeanIoU: 0.2002 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6198 - val_loss: 0.4951 Epoch 29/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 73ms/step - MeanIoU: 0.2024 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6895 - loss: 0.379
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2024 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6892 - loss: 0.3796 - val_MeanIoU: 0.2229 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6704 - val_loss: 0.4801 Epoch 30/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - MeanIoU: 0.2027 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6837 - loss: 0.393
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2027 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6836 - loss: 0.3935 - val_MeanIoU: 0.1855 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6057 - val_loss: 0.5543 Epoch 31/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - MeanIoU: 0.1981 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6964 - loss: 0.386
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.1981 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6963 - loss: 0.3861 - val_MeanIoU: 0.1994 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.7256 - val_loss: 0.4344 Epoch 32/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 73ms/step - MeanIoU: 0.2029 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6942 - loss: 0.389
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2028 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6941 - loss: 0.3901 - val_MeanIoU: 0.2168 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6799 - val_loss: 0.4103 Epoch 33/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 78ms/step - MeanIoU: 0.2034 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6679 - loss: 0.399
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2035 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6681 - loss: 0.3993 - val_MeanIoU: 0.2016 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.7077 - val_loss: 0.4347 Epoch 34/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - MeanIoU: 0.2018 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6803 - loss: 0.400
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2018 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6804 - loss: 0.4009 - val_MeanIoU: 0.2158 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6156 - val_loss: 0.4675 Epoch 35/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - MeanIoU: 0.2028 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6906 - loss: 0.373
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2030 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6902 - loss: 0.3739 - val_MeanIoU: 0.2052 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.7184 - val_loss: 0.3623 Epoch 36/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - MeanIoU: 0.2006 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.7026 - loss: 0.363
32/32 ━━━━━━━━━━━━━━━━━━━━ 66s 2s/step - MeanIoU: 0.2007 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.7023 - loss: 0.3637 - val_MeanIoU: 0.2227 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6607 - val_loss: 0.3603 Epoch 37/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 71ms/step - MeanIoU: 0.1988 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6915 - loss: 0.381
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.1989 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6914 - loss: 0.3816 - val_MeanIoU: 0.2118 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6805 - val_loss: 0.3887 Epoch 38/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - MeanIoU: 0.1990 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.7049 - loss: 0.361
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.1992 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.7044 - loss: 0.3625 - val_MeanIoU: 0.2003 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.7084 - val_loss: 0.4957 Epoch 39/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - MeanIoU: 0.2047 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6970 - loss: 0.366
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2047 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6970 - loss: 0.3665 - val_MeanIoU: 0.1988 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6954 - val_loss: 0.3742 Epoch 40/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - MeanIoU: 0.2096 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6706 - loss: 0.401
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2094 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6710 - loss: 0.4015 - val_MeanIoU: 0.2065 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.7359 - val_loss: 0.5942 Epoch 41/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - MeanIoU: 0.2073 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6944 - loss: 0.357
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2072 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6944 - loss: 0.3576 - val_MeanIoU: 0.2351 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6398 - val_loss: 0.4517 Epoch 42/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 54ms/step - MeanIoU: 0.2052 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6878 - loss: 0.369
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2053 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6874 - loss: 0.3703 - val_MeanIoU: 0.2146 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6190 - val_loss: 0.4801 Epoch 43/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - MeanIoU: 0.2019 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6970 - loss: 0.374
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2020 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6968 - loss: 0.3748 - val_MeanIoU: 0.2058 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.7089 - val_loss: 0.3656 Epoch 44/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - MeanIoU: 0.2056 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6960 - loss: 0.357
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2055 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6963 - loss: 0.3576 - val_MeanIoU: 0.2261 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6576 - val_loss: 0.4152 Epoch 45/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - MeanIoU: 0.2048 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6932 - loss: 0.357
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2049 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6930 - loss: 0.3579 - val_MeanIoU: 0.2119 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6908 - val_loss: 0.4689 Epoch 46/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - MeanIoU: 0.1956 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.7229 - loss: 0.348
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.1959 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.7220 - loss: 0.3494 - val_MeanIoU: 0.2000 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.7275 - val_loss: 0.3793
Model Learning Curve¶
In [25]:
# Model History
history = unet_model_history.history
history
Out[25]:
{'MeanIoU': [0.1992032825946808,
0.20723021030426025,
0.2089681178331375,
0.20103561878204346,
0.19850480556488037,
0.20233049988746643,
0.20658963918685913,
0.2036864310503006,
0.20984354615211487,
0.2046331912279129,
0.20716583728790283,
0.20179182291030884,
0.20636393129825592,
0.2093118131160736,
0.20593519508838654,
0.2030024528503418,
0.21003907918930054,
0.20242716372013092,
0.20756644010543823,
0.20600050687789917,
0.20601294934749603,
0.20151656866073608,
0.20484960079193115,
0.20811578631401062,
0.20776106417179108,
0.19975927472114563,
0.20528320968151093,
0.2022167146205902,
0.20413976907730103,
0.20312047004699707,
0.20042486488819122,
0.2013271152973175,
0.20636247098445892,
0.20346450805664062,
0.20794083178043365,
0.20383527874946594,
0.20293761789798737,
0.203046053647995,
0.2042224109172821,
0.20371179282665253,
0.20469354093074799,
0.2084343433380127,
0.2064168006181717,
0.20335863530635834,
0.20726317167282104,
0.20414483547210693],
'PixelAccuracy': [1.9073485191256623e-07,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
'dice_coeff': [0.5992616415023804,
0.6143656373023987,
0.6241587400436401,
0.647817075252533,
0.6599338054656982,
0.6411471366882324,
0.6401103734970093,
0.6434901356697083,
0.6323947906494141,
0.6539073586463928,
0.6431247591972351,
0.6608415842056274,
0.6543480157852173,
0.6542718410491943,
0.6568323969841003,
0.6639542579650879,
0.6646468043327332,
0.6780006289482117,
0.6735907196998596,
0.6692067384719849,
0.671150267124176,
0.6757012009620667,
0.6747865080833435,
0.6676671504974365,
0.6718212366104126,
0.6775573492050171,
0.6802664995193481,
0.6669780611991882,
0.6803557872772217,
0.6801664233207703,
0.6923991441726685,
0.6913303136825562,
0.6717274188995361,
0.6860343813896179,
0.6770738959312439,
0.6927362680435181,
0.6867203116416931,
0.6892097592353821,
0.696695864200592,
0.6829046607017517,
0.6935202479362488,
0.6724157929420471,
0.6902781128883362,
0.7051346898078918,
0.6867726445198059,
0.6933310031890869],
'loss': [0.6324572563171387,
0.5349217653274536,
0.5111037492752075,
0.4871182441711426,
0.4752787947654724,
0.4861675202846527,
0.4822116494178772,
0.4696265757083893,
0.47592490911483765,
0.4478815197944641,
0.4760293960571289,
0.4480341970920563,
0.4450124204158783,
0.44162890315055847,
0.45347458124160767,
0.44626083970069885,
0.42650797963142395,
0.40765997767448425,
0.40924662351608276,
0.4210212826728821,
0.40673133730888367,
0.41711413860321045,
0.4102461338043213,
0.4149117171764374,
0.421040415763855,
0.4233340620994568,
0.39630216360092163,
0.424282968044281,
0.39938491582870483,
0.4032699167728424,
0.38839882612228394,
0.3984592854976654,
0.401938796043396,
0.398421049118042,
0.3912218511104584,
0.38229379057884216,
0.38274428248405457,
0.38439831137657166,
0.37204596400260925,
0.4020407497882843,
0.37592634558677673,
0.40169450640678406,
0.3803367614746094,
0.35325393080711365,
0.3741324841976166,
0.3742380738258362],
'val_MeanIoU': [0.19999520480632782,
0.20395812392234802,
0.20691031217575073,
0.21234601736068726,
0.20294633507728577,
0.21694649755954742,
0.17849156260490417,
0.19981227815151215,
0.18399962782859802,
0.20628155767917633,
0.20395368337631226,
0.20939287543296814,
0.20003190636634827,
0.198095440864563,
0.2160402089357376,
0.22242309153079987,
0.22348520159721375,
0.1886315941810608,
0.21554818749427795,
0.21313956379890442,
0.22003304958343506,
0.19160348176956177,
0.1996234953403473,
0.19158752262592316,
0.21338379383087158,
0.20552490651607513,
0.20871075987815857,
0.20019182562828064,
0.22285757958889008,
0.18552769720554352,
0.19937221705913544,
0.216833233833313,
0.2015785425901413,
0.21578744053840637,
0.20519505441188812,
0.22265677154064178,
0.2117667943239212,
0.20026306807994843,
0.19879455864429474,
0.2064943164587021,
0.23507864773273468,
0.214645653963089,
0.20577462017536163,
0.22608904540538788,
0.2119479775428772,
0.1999789923429489],
'val_PixelAccuracy': [0.3757653534412384,
0.3474501073360443,
0.3203323781490326,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0],
'val_dice_coeff': [0.002381600672379136,
0.0017498626839369535,
0.0034958492033183575,
0.015481376089155674,
0.042307049036026,
0.07937664538621902,
0.015303398482501507,
0.04697667807340622,
0.13179059326648712,
0.21703331172466278,
0.5800686478614807,
0.2786713242530823,
0.47563308477401733,
0.28178802132606506,
0.4432109296321869,
0.5913054347038269,
0.3282725512981415,
0.41223838925361633,
0.5021713376045227,
0.6229732632637024,
0.4769725501537323,
0.7351338267326355,
0.7330007553100586,
0.678837239742279,
0.5442424416542053,
0.6548649668693542,
0.5096802115440369,
0.6198228597640991,
0.6704198122024536,
0.6056585907936096,
0.7256316542625427,
0.6798519492149353,
0.70768803358078,
0.615635871887207,
0.7184128761291504,
0.6607317924499512,
0.6804813742637634,
0.7083863019943237,
0.6953719258308411,
0.7359205484390259,
0.6398470997810364,
0.6189655065536499,
0.7089042067527771,
0.6576145887374878,
0.690753161907196,
0.727544903755188],
'val_loss': [2211.156005859375,
158.6863555908203,
78.70732116699219,
25.373165130615234,
19.065534591674805,
2.7797865867614746,
14.478734970092773,
4.354736328125,
2.5210089683532715,
2.6700148582458496,
0.6075860857963562,
2.665254592895508,
1.173814058303833,
1.7027209997177124,
1.3020278215408325,
0.7013348937034607,
1.285564661026001,
1.3875261545181274,
0.8983238935470581,
0.5877137184143066,
0.8205947875976562,
0.521854043006897,
0.6408008933067322,
0.4418080747127533,
0.5785664319992065,
0.4362673759460449,
0.7655749917030334,
0.4951431155204773,
0.48009970784187317,
0.5542751550674438,
0.434432715177536,
0.4103431701660156,
0.4347221553325653,
0.4674955904483795,
0.36232373118400574,
0.3602845370769501,
0.3887442946434021,
0.49567872285842896,
0.37423327565193176,
0.594249427318573,
0.45169612765312195,
0.4800850450992584,
0.36560124158859253,
0.41520828008651733,
0.46888765692710876,
0.37928253412246704]}
Model Predictions¶
In [26]:
show_images_and_masks(data=train_ds, model=unet_model)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 54ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 46ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 42ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 41ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step
In [27]:
show_images_and_masks(data=valid_ds, model=unet_model)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 52ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 47ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 50ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 53ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 49ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 85ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 42ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 68ms/step
In [28]:
show_images_and_masks(data=test_ds, model=unet_model)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 53ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 53ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 48ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 47ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 41ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 48ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 70ms/step
Evaluation¶
In [29]:
# U-Net test accuracy
results = unet_model.evaluate(test_ds, verbose=1)
# print(f'Test Loss: {test_loss:.4f}')
# print(f'Test Accuracy: {test_accuracy:.4f}')
# print(f'Test IoU: {test_iou:.4f}')
7/7 ━━━━━━━━━━━━━━━━━━━━ 3s 491ms/step - MeanIoU: 0.2065 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6806 - loss: 0.3672
Show train and validation loss¶
In [30]:
epochs = range(1, len(unet_model_history.history["loss"]) + 1)
loss = unet_model_history.history["loss"]
val_loss = unet_model_history.history["val_loss"]
plt.figure()
plt.plot(epochs, loss, "r", label="Training loss")
plt.plot(epochs, val_loss, "b", label="Validation loss")
plt.title("Training and validation loss")
plt.legend()
Out[30]:
<matplotlib.legend.Legend at 0x161b67cdc10>
In [32]:
# Extract the history object from the training
history = unet_model_history.history
# Define the list of metrics you want to plot
evaluation_metrics = ['PixelAccuracy', 'MeanIoU', 'dice_coeff']
# Create subplots for each metric
plt.figure(figsize=(18, 6))
for i, metric in enumerate(evaluation_metrics):
plt.subplot(1, len(evaluation_metrics), i + 1)
plt.plot(history[metric], label='Train')
plt.plot(history[f'val_{metric}'], label='Validation')
plt.title(f'{metric} over Epochs')
plt.xlabel('Epoch')
plt.ylabel(metric)
plt.legend()
plt.tight_layout()
plt.show()
SegNet (A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation)¶
In [ ]:
# #import tensorflow as tf
# from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, BatchNormalization, ReLU
# from tensorflow.keras.models import Model
# #from tensorflow.keras.optimizers import Adam
# #from tensorflow.keras.losses import BinaryCrossentropy
# #from tensorflow.keras.metrics import BinaryAccuracy, MeanIoU
# from tensorflow.keras import metrics
# class SegNet:
# def __init__(self, input_shape, num_classes=1, base_lr=0.001):
# self.input_shape = input_shape
# self.num_classes = num_classes
# self.base_lr = base_lr
# self.model = self.build_model()
# self.compile()
# def conv_block(self, x, filters, kernel_size=3, strides=1, padding='same'):
# x = Conv2D(filters, kernel_size=kernel_size, strides=strides, padding=padding)(x)
# x = BatchNormalization()(x)
# x = ReLU()(x)
# return x
# def encoder_block(self, x, filters):
# x = self.conv_block(x, filters)
# x = self.conv_block(x, filters)
# p = MaxPooling2D(pool_size=2, strides=2, padding='same')(x)
# return x, p
# def decoder_block(self, x, skip, filters):
# x = Conv2DTranspose(filters, kernel_size=3, strides=2, padding='same')(x)
# x = tf.concat([x, skip], axis=-1)
# x = self.conv_block(x, filters)
# x = self.conv_block(x, filters)
# return x
# def build_model(self):
# inputs = Input(shape=self.input_shape)
# # Encoder
# e1, p1 = self.encoder_block(inputs, 64)
# e2, p2 = self.encoder_block(p1, 128)
# e3, p3 = self.encoder_block(p2, 256)
# e4, p4 = self.encoder_block(p3, 512)
# e5, p5 = self.encoder_block(p4, 512)
# # Decoder
# d5 = self.decoder_block(p5, e5, 512)
# d4 = self.decoder_block(d5, e4, 512)
# d3 = self.decoder_block(d4, e3, 256)
# d2 = self.decoder_block(d3, e2, 128)
# d1 = self.decoder_block(d2, e1, 64)
# # Final layer for binary segmentation
# outputs = Conv2D(self.num_classes, kernel_size=1, activation='sigmoid')(d1)
# model = Model(inputs=inputs, outputs=outputs)
# return model
# def compile(self):
# # Define metrics
# pixel_acc = metrics.Accuracy(name="PixelAccuracy")
# mean_iou = metrics.MeanIoU(num_classes=2, name="MeanIoU")
# # Learning Rate Schedule
# lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
# initial_learning_rate=self.base_lr,
# decay_steps=500,
# decay_rate=0.96,
# staircase=True
# )
# # Optimizer
# optimizer = optimizers.Adam(learning_rate=lr_schedule)
# # Compile the model
# self.model.compile(
# loss='binary_crossentropy',
# optimizer=optimizer,
# metrics=[pixel_acc, mean_iou, dice_coeff]
# )
# def summary(self):
# self.model.summary()
In [ ]:
# input_shape = (IMAGE_WIDTH, IMAGE_HEIGHT, N_IMAGE_CHANNELS)
# num_classes = 1 # Single channel output for binary segmentation
# # Create the SegNet model
# segnet = SegNet(input_shape, num_classes)
# # Compile the model
# segnet.compile()
# # Print model summary
# segnet.summary()
In [ ]:
# # Model Training
# segnet_model_history = segnet.model.fit(
# train_ds,
# validation_data = valid_ds,
# epochs = EPOCHS,
# callbacks = CALLBACKS,
# batch_size = BATCH_SIZE,
# )
In [ ]: